{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from sklearn.datasets import make_classification\n",
    "from sklearn.model_selection import train_test_split\n",
    "from sklearn.metrics import accuracy_score, f1_score\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "from pytorch_tabular.utils import make_mixed_dataset, print_metrics\n",
    "%load_ext autoreload\n",
    "%autoreload 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "# Load dataset\n",
    "data, cat_col_names, num_col_names = make_mixed_dataset(task=\"classification\", n_samples=10000, n_features=8, n_categories=4, weights=[0.8], random_state=42)\n",
    "\n",
    "# Create a new, second target\n",
    "data['second_target'] = 0\n",
    "for c in cat_col_names:\n",
    "    data.second_target += data[c]\n",
    "\n",
    "# Correlate it to 1st target\n",
    "data.second_target += (data.target == 'class_0')\n",
    "\n",
    "# Create random discrete noise to make task non-trivial\n",
    "random_noise = np.random.normal(0, 0.8, data.shape[0]).astype(int)\n",
    "data.second_target = (data.second_target + random_noise).mod(3) \n",
    "\n",
    "# Now that dataset is complete, we can perform train/test split\n",
    "train, test = train_test_split(data, random_state=42)\n",
    "train, val = train_test_split(train, random_state=42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "target\n",
       "class_0    0.7968\n",
       "class_1    0.2032\n",
       "Name: proportion, dtype: float64"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.target.value_counts(normalize=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "second_target\n",
       "2.0    0.3364\n",
       "1.0    0.3354\n",
       "0.0    0.3282\n",
       "Name: proportion, dtype: float64"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data.second_target.value_counts(normalize=True)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "# Importing the Library"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "from pytorch_tabular import TabularModel\n",
    "from pytorch_tabular.models import CategoryEmbeddingModelConfig\n",
    "from pytorch_tabular.config import DataConfig, OptimizerConfig, TrainerConfig, ExperimentConfig\n",
    "from pytorch_tabular.models.common.heads import LinearHeadConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "results = []"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Define the Configs\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "Collapsed": "false"
   },
   "outputs": [],
   "source": [
    "trainer_config = TrainerConfig(\n",
    "    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate\n",
    "    batch_size=1024,\n",
    "    max_epochs=100,\n",
    "    early_stopping=\"valid_loss\", # Monitor valid_loss for early stopping\n",
    "    early_stopping_mode = \"min\", # Set the mode as min because for val_loss, lower is better\n",
    "    early_stopping_patience=5, # No. of epochs of degradation training will wait before terminating\n",
    "    checkpoints=\"valid_loss\", # Save best checkpoint monitoring val_loss\n",
    "    load_best=True, # After training, load the best checkpoint\n",
    "#     accelerator=\"cpu\"\n",
    ")\n",
    "optimizer_config = OptimizerConfig()\n",
    "\n",
    "head_config = LinearHeadConfig(\n",
    "    layers=\"\", # No additional layer in head, just a mapping layer to output_dim\n",
    "    dropout=0.1,\n",
    "    initialization=\"kaiming\"\n",
    ").__dict__ # Convert to dict to pass to the model config (OmegaConf doesn't accept objects)\n",
    "\n",
    "model_config = CategoryEmbeddingModelConfig(\n",
    "    task=\"classification\",\n",
    "    layers=\"1024-512-512\",  # Number of nodes in each layer\n",
    "    activation=\"LeakyReLU\", # Activation between each layers\n",
    "    head = \"LinearHead\", #Linear Head\n",
    "    head_config = head_config, # Linear Head Config\n",
    "    learning_rate = 1e-3,\n",
    "    metrics=[\"f1_score\",\"accuracy\",\"auroc\"], \n",
    "    metrics_prob_input=[True, False, True] # f1_score needs probability scores, while accuracy doesn't\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Data Config For Each Target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "data_config_first_target = DataConfig(\n",
    "    target=['target'], #target should always be a list\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names\n",
    ")\n",
    "\n",
    "data_config_second_target = DataConfig(\n",
    "    target=['second_target'], #target should always be a list\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names\n",
    ")"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {
    "Collapsed": "false"
   },
   "source": [
    "## Training the Single-Target Model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "Collapsed": "false",
    "tags": []
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:02</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">110</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">140</span><span style=\"font-weight: bold\">}</span> - INFO - Experiment Tracking is turned off           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m110\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:02</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">126</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">541</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the DataLoaders                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m126\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m541\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:02</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">130</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_datamodul<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e:507</span><span style=\"font-weight: bold\">}</span> - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m130\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:507\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:02</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">150</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">591</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Model: CategoryEmbeddingModel \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m150\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m591\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: CategoryEmbeddingModel \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:02</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">190</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">338</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Trainer                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m190\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m338\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (mps), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:02</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">366</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">647</span><span style=\"font-weight: bold\">}</span> - INFO - Auto LR Find Started                        \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:02\u001b[0m,\u001b[1;36m366\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m647\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started                        \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/yony/Code/pytorch-tabular-public/docs/tutorials/saved_models exists and is not empty.\n",
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n",
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d2cae94714354ea09299a9e44f238d89",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LR finder stopped early after 92 steps due to diverging loss.\n",
      "Learning rate set to 0.025118864315095822\n",
      "Restoring states from the checkpoint path at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_744b05b3-1568-4cf0-97ac-eaefc0c795f4.ckpt\n",
      "Restored all states from the checkpoint at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_744b05b3-1568-4cf0-97ac-eaefc0c795f4.ckpt\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:10</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">876</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">660</span><span style=\"font-weight: bold\">}</span> - INFO - Suggested LR: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.025118864315095822</span>. For plot\n",
       "and detailed analysis, use `find_learning_rate` method.                                                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:10\u001b[0m,\u001b[1;36m876\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m660\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m0.025118864315095822\u001b[0m. For plot\n",
       "and detailed analysis, use `find_learning_rate` method.                                                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:10</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">881</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">669</span><span style=\"font-weight: bold\">}</span> - INFO - Training Started                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:10\u001b[0m,\u001b[1;36m881\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m669\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                      </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ CategoryEmbeddingBackbone │  804 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding1dLayer          │     68 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ head             │ LinearHead                │  1.0 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                     \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ CategoryEmbeddingBackbone │  804 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer          │     68 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head             │ LinearHead                │  1.0 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 805 K                                                                                            \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 805 K                                                                                                \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 3                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 805 K                                                                                            \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 805 K                                                                                                \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 3                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d67a105c98184698acff2a3caab52b74",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:29</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">013</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">680</span><span style=\"font-weight: bold\">}</span> - INFO - Training the model completed                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:29\u001b[0m,\u001b[1;36m013\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m680\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:29</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">014</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1512</span><span style=\"font-weight: bold\">}</span> - INFO - Loading the best model                     \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:29\u001b[0m,\u001b[1;36m014\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1512\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model                     \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "543ac972faed487c9fa81364d979f3f7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9484000205993652     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_auroc         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9717501997947693     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_f1_score       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9484000205993652     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.17070142924785614    </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_0        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.17070142924785614    </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9484000205993652    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_auroc        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9717501997947693    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_f1_score      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9484000205993652    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.17070142924785614   \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_0       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.17070142924785614   \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tabular_model = TabularModel(\n",
    "    data_config=data_config_first_target,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    ")\n",
    "\n",
    "tabular_model.fit(train=train, validation=val)\n",
    "\n",
    "result = tabular_model.evaluate(test)\n",
    "\n",
    "result = {k: float(v) for k,v in result[0].items()}\n",
    "result = pd.DataFrame({'f1':result['test_f1_score'],'auroc':result['test_auroc']},\n",
    "             index=['1st Target (single mode)'])\n",
    "results.append(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Now train separately on the 2nd target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:30</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">958</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">140</span><span style=\"font-weight: bold\">}</span> - INFO - Experiment Tracking is turned off           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:30\u001b[0m,\u001b[1;36m958\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:30</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">972</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">541</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the DataLoaders                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:30\u001b[0m,\u001b[1;36m972\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m541\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:30</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">976</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_datamodul<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e:507</span><span style=\"font-weight: bold\">}</span> - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:30\u001b[0m,\u001b[1;36m976\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:507\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:30</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">996</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">591</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Model: CategoryEmbeddingModel \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:30\u001b[0m,\u001b[1;36m996\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m591\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: CategoryEmbeddingModel \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:31</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">035</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">338</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Trainer                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:31\u001b[0m,\u001b[1;36m035\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m338\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (mps), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:31</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">054</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">647</span><span style=\"font-weight: bold\">}</span> - INFO - Auto LR Find Started                        \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:31\u001b[0m,\u001b[1;36m054\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m647\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started                        \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/yony/Code/pytorch-tabular-public/docs/tutorials/saved_models exists and is not empty.\n",
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n",
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ace5c6e2844d4ff58f2550a8e9da9ee4",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LR finder stopped early after 91 steps due to diverging loss.\n",
      "Learning rate set to 0.0002511886431509582\n",
      "Restoring states from the checkpoint path at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_cd3e2f01-7f2a-4103-ac17-9b4b4f8aeed1.ckpt\n",
      "Restored all states from the checkpoint at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_cd3e2f01-7f2a-4103-ac17-9b4b4f8aeed1.ckpt\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:38</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">580</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">660</span><span style=\"font-weight: bold\">}</span> - INFO - Suggested LR: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.0002511886431509582</span>. For    \n",
       "plot and detailed analysis, use `find_learning_rate` method.                                                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:38\u001b[0m,\u001b[1;36m580\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m660\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m0.0002511886431509582\u001b[0m. For    \n",
       "plot and detailed analysis, use `find_learning_rate` method.                                                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:38</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">584</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">669</span><span style=\"font-weight: bold\">}</span> - INFO - Training Started                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:38\u001b[0m,\u001b[1;36m584\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m669\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                      </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ CategoryEmbeddingBackbone │  804 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding1dLayer          │     68 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ head             │ LinearHead                │  1.5 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                     \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ CategoryEmbeddingBackbone │  804 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer          │     68 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head             │ LinearHead                │  1.5 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 806 K                                                                                            \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 806 K                                                                                                \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 3                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 806 K                                                                                            \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 806 K                                                                                                \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 3                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "483e9eeb0b754be9b966618d11ab2baa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:54</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">807</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">680</span><span style=\"font-weight: bold\">}</span> - INFO - Training the model completed                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:54\u001b[0m,\u001b[1;36m807\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m680\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:54</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">808</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1512</span><span style=\"font-weight: bold\">}</span> - INFO - Loading the best model                     \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:54\u001b[0m,\u001b[1;36m808\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1512\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model                     \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0fad2da095be4448ae4507d193f28c4a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.6808000206947327     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_auroc         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.8134375810623169     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_f1_score       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.6808000206947327     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.8243984580039978     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_0        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.8243984580039978     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.6808000206947327    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_auroc        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.8134375810623169    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_f1_score      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.6808000206947327    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.8243984580039978    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_0       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.8243984580039978    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "tabular_model = TabularModel(\n",
    "    data_config=data_config_second_target,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    ")\n",
    "\n",
    "tabular_model.fit(train=train, validation=val)\n",
    "\n",
    "result = tabular_model.evaluate(test)\n",
    "\n",
    "result = {k: float(v) for k,v in result[0].items()}\n",
    "result = pd.DataFrame({'f1':result['test_f1_score'],'auroc':result['test_auroc']},\n",
    "             index=['2nd Target (single mode)'])\n",
    "\n",
    "results.append(result)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi-Target Training\n",
    "\n",
    "Instead of training one model for the first target, and another model for the second target, we can train a single model that will make a prediction for both targets.\n",
    "\n",
    "This is usually beneficial in reducing training time, but may also lead to better results since the model may have a better representation (embedding) by learning from multiple targets.\n",
    "\n",
    "To perform multi-target training, we only need to model the 'target' field in the data_config to include a list of multiple targets.\n",
    "\n",
    "Results are reported on the sum of all metrics (f1, AU-ROC, etc.), as well as a list of results for each target with the suffix '_n' (starting at n=1), for example, the f1 score for the 2nd target is test_f1_score_0."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:57</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">010</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">140</span><span style=\"font-weight: bold\">}</span> - INFO - Experiment Tracking is turned off           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m010\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:57</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">023</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">541</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the DataLoaders                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m023\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m541\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:57</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">028</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_datamodul<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e:507</span><span style=\"font-weight: bold\">}</span> - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m028\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:507\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:57</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">049</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">591</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Model: CategoryEmbeddingModel \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m049\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m591\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: CategoryEmbeddingModel \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:57</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">089</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">338</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Trainer                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m089\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m338\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (mps), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:38:57</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">115</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">647</span><span style=\"font-weight: bold\">}</span> - INFO - Auto LR Find Started                        \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:38:57\u001b[0m,\u001b[1;36m115\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m647\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started                        \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/yony/Code/pytorch-tabular-public/docs/tutorials/saved_models exists and is not empty.\n",
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n",
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c952afd65ca84628a3d1142ddc2e936e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "LR finder stopped early after 94 steps due to diverging loss.\n",
      "Learning rate set to 4.786300923226385e-05\n",
      "Restoring states from the checkpoint path at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_90994562-d7c8-47ac-974c-b9c85a34479e.ckpt\n",
      "Restored all states from the checkpoint at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_90994562-d7c8-47ac-974c-b9c85a34479e.ckpt\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:39:06</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">162</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">660</span><span style=\"font-weight: bold\">}</span> - INFO - Suggested LR: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">4.786300923226385e-05</span>. For    \n",
       "plot and detailed analysis, use `find_learning_rate` method.                                                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:39:06\u001b[0m,\u001b[1;36m162\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m660\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m4.786300923226385e-05\u001b[0m. For    \n",
       "plot and detailed analysis, use `find_learning_rate` method.                                                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:39:06</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">166</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">669</span><span style=\"font-weight: bold\">}</span> - INFO - Training Started                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:39:06\u001b[0m,\u001b[1;36m166\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m669\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type                      </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ CategoryEmbeddingBackbone │  804 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding1dLayer          │     68 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ head             │ LinearHead                │  2.6 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType                     \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ CategoryEmbeddingBackbone │  804 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer          │     68 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ head             │ LinearHead                │  2.6 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss          │      0 │\n",
       "└───┴──────────────────┴───────────────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 807 K                                                                                            \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 807 K                                                                                                \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 3                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 807 K                                                                                            \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 807 K                                                                                                \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 3                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "49d66315050e49ed94823d49ea1ff8c0",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_epochs=100` reached.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:40:07</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">357</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">680</span><span style=\"font-weight: bold\">}</span> - INFO - Training the model completed                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:07\u001b[0m,\u001b[1;36m357\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m680\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:40:07</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">363</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1512</span><span style=\"font-weight: bold\">}</span> - INFO - Loading the best model                     \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:07\u001b[0m,\u001b[1;36m363\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1512\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model                     \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b06cba39db194aaea56bc99fd05c334b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    1.5648000240325928     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test_accuracy_0      </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     0.946399986743927     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test_accuracy_1      </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     0.618399977684021     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_auroc         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    1.7456306219100952     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_auroc_0        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9698674082756042     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_auroc_1        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.7757631540298462     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_f1_score       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    1.5648000240325928     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test_f1_score_0      </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     0.946399986743927     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test_f1_score_1      </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     0.618399977684021     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    1.0632164478302002     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_0        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.16491228342056274    </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_1        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.8983041048049927     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   1.5648000240325928    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test_accuracy_0     \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    0.946399986743927    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test_accuracy_1     \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    0.618399977684021    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_auroc        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   1.7456306219100952    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_auroc_0       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9698674082756042    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_auroc_1       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.7757631540298462    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_f1_score      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   1.5648000240325928    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test_f1_score_0     \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    0.946399986743927    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test_f1_score_1     \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    0.618399977684021    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   1.0632164478302002    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_0       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.16491228342056274   \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_1       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.8983041048049927    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "# `target` is always a list; listing two columns enables multi-target mode\n",
    "data_config_multi = DataConfig(\n",
    "    target=['target', 'second_target'],\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names,\n",
    ")\n",
    "\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config_multi,\n",
    "    model_config=model_config,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_config,\n",
    ")\n",
    "tabular_model.fit(train=train, validation=val)\n",
    "\n",
    "# Keep the first entry returned by evaluate() and cast each metric to a plain float\n",
    "result = tabular_model.evaluate(test)\n",
    "result = {metric: float(score) for metric, score in result[0].items()}\n",
    "\n",
    "# Metric keys ending in _0 belong to the 1st target, keys ending in _1 to the 2nd\n",
    "result1 = pd.DataFrame(\n",
    "    {'f1': result['test_f1_score_0'], 'auroc': result['test_auroc_0']},\n",
    "    index=['1st Target (multi-target mode)'],\n",
    ")\n",
    "results.append(result1)\n",
    "\n",
    "result2 = pd.DataFrame(\n",
    "    {'f1': result['test_f1_score_1'], 'auroc': result['test_auroc_1']},\n",
    "    index=['2nd Target (multi-target mode)'],\n",
    ")\n",
    "results.append(result2)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<style type=\"text/css\">\n",
       "#T_71fc4_row0_col0, #T_71fc4_row0_col1 {\n",
       "  background-color: lightgreen;\n",
       "}\n",
       "</style>\n",
       "<table id=\"T_71fc4\">\n",
       "  <thead>\n",
       "    <tr>\n",
       "      <th class=\"blank level0\" >&nbsp;</th>\n",
       "      <th id=\"T_71fc4_level0_col0\" class=\"col_heading level0 col0\" >f1</th>\n",
       "      <th id=\"T_71fc4_level0_col1\" class=\"col_heading level0 col1\" >auroc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th id=\"T_71fc4_level0_row0\" class=\"row_heading level0 row0\" >1st Target (single mode)</th>\n",
       "      <td id=\"T_71fc4_row0_col0\" class=\"data row0 col0\" >0.948400</td>\n",
       "      <td id=\"T_71fc4_row0_col1\" class=\"data row0 col1\" >0.971750</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_71fc4_level0_row1\" class=\"row_heading level0 row1\" >2nd Target (single mode)</th>\n",
       "      <td id=\"T_71fc4_row1_col0\" class=\"data row1 col0\" >0.680800</td>\n",
       "      <td id=\"T_71fc4_row1_col1\" class=\"data row1 col1\" >0.813438</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_71fc4_level0_row2\" class=\"row_heading level0 row2\" >1st Target (multi-target mode)</th>\n",
       "      <td id=\"T_71fc4_row2_col0\" class=\"data row2 col0\" >0.946400</td>\n",
       "      <td id=\"T_71fc4_row2_col1\" class=\"data row2 col1\" >0.969867</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th id=\"T_71fc4_level0_row3\" class=\"row_heading level0 row3\" >2nd Target (multi-target mode)</th>\n",
       "      <td id=\"T_71fc4_row3_col0\" class=\"data row3 col0\" >0.618400</td>\n",
       "      <td id=\"T_71fc4_row3_col1\" class=\"data row3 col1\" >0.775763</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n"
      ],
      "text/plain": [
       "<pandas.io.formats.style.Styler at 0x315726880>"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Stack the one-row result frames and highlight the best value in each metric column\n",
    "res_df = pd.concat(results)\n",
    "res_df.style.highlight_max(axis=0, color=\"lightgreen\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In this run, the multi-target model performed on par with the single-target model on the 1st target, but slightly worse on the 2nd target. Results may vary for this artificial dataset depending on random number generation, and in general a multi-target model may not outperform its single-target variants. Additional tuning may be needed in multi-target mode, e.g. a larger embedding size."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Deep Learning Model\n",
    "\n",
    "We can also test whether a deeper model benefits from the shared embedding. In this case, we test GANDALF with multi-target classification, similar to our experiment above.\n",
    "\n",
    "Note: Without an accelerator (GPU), training will take a considerable amount of time on a CPU."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "from pytorch_tabular.models import GANDALFConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:40:09</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">147</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">140</span><span style=\"font-weight: bold\">}</span> - INFO - Experiment Tracking is turned off           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m147\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m140\u001b[0m\u001b[1m}\u001b[0m - INFO - Experiment Tracking is turned off           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Seed set to 42\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:40:09</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">160</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">541</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the DataLoaders                   \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m160\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m541\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the DataLoaders                   \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:40:09</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">165</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_datamodul<span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">e:507</span><span style=\"font-weight: bold\">}</span> - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m165\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_datamodul\u001b[1;92me:507\u001b[0m\u001b[1m}\u001b[0m - INFO - Setting up the datamodule for          \n",
       "classification task                                                                                                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:40:09</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">188</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">591</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Model: GANDALFModel           \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m188\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m591\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Model: GANDALFModel           \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:40:09</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">212</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">338</span><span style=\"font-weight: bold\">}</span> - INFO - Preparing the Trainer                       \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m212\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m338\u001b[0m\u001b[1m}\u001b[0m - INFO - Preparing the Trainer                       \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "GPU available: True (mps), used: True\n",
      "TPU available: False, using: 0 TPU cores\n",
      "IPU available: False, using: 0 IPUs\n",
      "HPU available: False, using: 0 HPUs\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:40:09</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">240</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">647</span><span style=\"font-weight: bold\">}</span> - INFO - Auto LR Find Started                        \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:09\u001b[0m,\u001b[1;36m240\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m647\u001b[0m\u001b[1m}\u001b[0m - INFO - Auto LR Find Started                        \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:653: Checkpoint directory /Users/yony/Code/pytorch-tabular-public/docs/tutorials/saved_models exists and is not empty.\n",
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n",
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py:298: The number of training batches (6) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.\n",
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3d222f1be7284f9aae6216f29e17ea4b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Finding best initial lr:   0%|          | 0/100 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "`Trainer.fit` stopped: `max_steps=100` reached.\n",
      "Learning rate set to 0.10964781961431852\n",
      "Restoring states from the checkpoint path at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_92ae2770-372c-495e-9410-7ff7723d419e.ckpt\n",
      "Restored all states from the checkpoint at /Users/yony/Code/pytorch-tabular-public/docs/tutorials/.lr_find_92ae2770-372c-495e-9410-7ff7723d419e.ckpt\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:40:22</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">112</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">660</span><span style=\"font-weight: bold\">}</span> - INFO - Suggested LR: <span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">0.10964781961431852</span>. For plot \n",
       "and detailed analysis, use `find_learning_rate` method.                                                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:22\u001b[0m,\u001b[1;36m112\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m660\u001b[0m\u001b[1m}\u001b[0m - INFO - Suggested LR: \u001b[1;36m0.10964781961431852\u001b[0m. For plot \n",
       "and detailed analysis, use `find_learning_rate` method.                                                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:40:22</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">115</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">669</span><span style=\"font-weight: bold\">}</span> - INFO - Training Started                            \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:22\u001b[0m,\u001b[1;36m115\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m669\u001b[0m\u001b[1m}\u001b[0m - INFO - Training Started                            \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\">   </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Name             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Type             </span>┃<span style=\"color: #800080; text-decoration-color: #800080; font-weight: bold\"> Params </span>┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 0 </span>│ _backbone        │ GANDALFBackbone  │  9.6 K │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 1 </span>│ _embedding_layer │ Embedding1dLayer │     68 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 2 </span>│ _head            │ Sequential       │     90 │\n",
       "│<span style=\"color: #7f7f7f; text-decoration-color: #7f7f7f\"> 3 </span>│ loss             │ CrossEntropyLoss │      0 │\n",
       "└───┴──────────────────┴──────────────────┴────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━┓\n",
       "┃\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mName            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mType            \u001b[0m\u001b[1;35m \u001b[0m┃\u001b[1;35m \u001b[0m\u001b[1;35mParams\u001b[0m\u001b[1;35m \u001b[0m┃\n",
       "┡━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━┩\n",
       "│\u001b[2m \u001b[0m\u001b[2m0\u001b[0m\u001b[2m \u001b[0m│ _backbone        │ GANDALFBackbone  │  9.6 K │\n",
       "│\u001b[2m \u001b[0m\u001b[2m1\u001b[0m\u001b[2m \u001b[0m│ _embedding_layer │ Embedding1dLayer │     68 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m2\u001b[0m\u001b[2m \u001b[0m│ _head            │ Sequential       │     90 │\n",
       "│\u001b[2m \u001b[0m\u001b[2m3\u001b[0m\u001b[2m \u001b[0m│ loss             │ CrossEntropyLoss │      0 │\n",
       "└───┴──────────────────┴──────────────────┴────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Trainable params</span>: 9.8 K                                                                                            \n",
       "<span style=\"font-weight: bold\">Non-trainable params</span>: 0                                                                                            \n",
       "<span style=\"font-weight: bold\">Total params</span>: 9.8 K                                                                                                \n",
       "<span style=\"font-weight: bold\">Total estimated model params size (MB)</span>: 0                                                                          \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mTrainable params\u001b[0m: 9.8 K                                                                                            \n",
       "\u001b[1mNon-trainable params\u001b[0m: 0                                                                                            \n",
       "\u001b[1mTotal params\u001b[0m: 9.8 K                                                                                                \n",
       "\u001b[1mTotal estimated model params size (MB)\u001b[0m: 0                                                                          \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "833c09b04f5a4841a69215c4161a94ec",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:40:28</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">324</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">680</span><span style=\"font-weight: bold\">}</span> - INFO - Training the model completed                \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:28\u001b[0m,\u001b[1;36m324\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m680\u001b[0m\u001b[1m}\u001b[0m - INFO - Training the model completed                \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">2024</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">07</span>-<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">10</span> <span style=\"color: #00ff00; text-decoration-color: #00ff00; font-weight: bold\">11:40:28</span>,<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">326</span> - <span style=\"font-weight: bold\">{</span>pytorch_tabular.tabular_model:<span style=\"color: #008080; text-decoration-color: #008080; font-weight: bold\">1512</span><span style=\"font-weight: bold\">}</span> - INFO - Loading the best model                     \n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1;36m2024\u001b[0m-\u001b[1;36m07\u001b[0m-\u001b[1;36m10\u001b[0m \u001b[1;92m11:40:28\u001b[0m,\u001b[1;36m326\u001b[0m - \u001b[1m{\u001b[0mpytorch_tabular.tabular_model:\u001b[1;36m1512\u001b[0m\u001b[1m}\u001b[0m - INFO - Loading the best model                     \n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "8db6da2e21c24377a31514cc5a245919",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/Users/yony/Code/pytorch-tabular-public/.venv/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=9` in the `DataLoader` to improve performance.\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃<span style=\"font-weight: bold\">        Test metric        </span>┃<span style=\"font-weight: bold\">       DataLoader 0        </span>┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_accuracy       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     1.210800051689148     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test_accuracy_0      </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9120000004768372     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test_accuracy_1      </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     0.298799991607666     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_auroc         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    1.4514379501342773     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_auroc_0        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9510558843612671     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_auroc_1        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     0.500382125377655     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">       test_f1_score       </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     1.210800051689148     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test_f1_score_0      </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.9120000004768372     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">      test_f1_score_1      </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     0.298799991607666     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">         test_loss         </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    1.3304287195205688     </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_0        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">    0.22366906702518463    </span>│\n",
       "│<span style=\"color: #008080; text-decoration-color: #008080\">        test_loss_1        </span>│<span style=\"color: #800080; text-decoration-color: #800080\">     1.106759786605835     </span>│\n",
       "└───────────────────────────┴───────────────────────────┘\n",
       "</pre>\n"
      ],
      "text/plain": [
       "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
       "┃\u001b[1m \u001b[0m\u001b[1m       Test metric       \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m      DataLoader 0       \u001b[0m\u001b[1m \u001b[0m┃\n",
       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_accuracy      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    1.210800051689148    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test_accuracy_0     \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9120000004768372    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test_accuracy_1     \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    0.298799991607666    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_auroc        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   1.4514379501342773    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_auroc_0       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9510558843612671    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_auroc_1       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    0.500382125377655    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m      test_f1_score      \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    1.210800051689148    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test_f1_score_0     \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.9120000004768372    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m     test_f1_score_1     \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    0.298799991607666    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m        test_loss        \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   1.3304287195205688    \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_0       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m   0.22366906702518463   \u001b[0m\u001b[35m \u001b[0m│\n",
       "│\u001b[36m \u001b[0m\u001b[36m       test_loss_1       \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m    1.106759786605835    \u001b[0m\u001b[35m \u001b[0m│\n",
       "└───────────────────────────┴───────────────────────────┘\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"></pre>\n"
      ],
      "text/plain": []
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>f1</th>\n",
       "      <th>auroc</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>1st Target (single mode)</th>\n",
       "      <td>0.9484</td>\n",
       "      <td>0.971750</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2nd Target (single mode)</th>\n",
       "      <td>0.6808</td>\n",
       "      <td>0.813438</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1st Target (multi-target mode)</th>\n",
       "      <td>0.9464</td>\n",
       "      <td>0.969867</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2nd Target (multi-target mode)</th>\n",
       "      <td>0.6184</td>\n",
       "      <td>0.775763</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1st Target (Gandalf, multi-target mode)</th>\n",
       "      <td>0.9120</td>\n",
       "      <td>0.951056</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2nd Target (Gandalf, multi-target mode)</th>\n",
       "      <td>0.2988</td>\n",
       "      <td>0.500382</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             f1     auroc\n",
       "1st Target (single mode)                 0.9484  0.971750\n",
       "2nd Target (single mode)                 0.6808  0.813438\n",
       "1st Target (multi-target mode)           0.9464  0.969867\n",
       "2nd Target (multi-target mode)           0.6184  0.775763\n",
       "1st Target (Gandalf, multi-target mode)  0.9120  0.951056\n",
       "2nd Target (Gandalf, multi-target mode)  0.2988  0.500382"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Multi-target GANDALF run: a single model is trained to predict both `target` and `second_target`.\n",
    "data_config_2nd = DataConfig(\n",
    "    target=['target', 'second_target'],\n",
    "    continuous_cols=num_col_names,\n",
    "    categorical_cols=cat_col_names,\n",
    ")\n",
    "\n",
    "trainer_gl_config = TrainerConfig(\n",
    "    auto_lr_find=True, # Runs the LRFinder to automatically derive a learning rate\n",
    "    batch_size=1024,\n",
    "    max_epochs=50,\n",
    "    early_stopping=\"valid_loss\", # Monitor valid_loss for early stopping\n",
    "    early_stopping_mode=\"min\", # Set the mode as min because for valid_loss, lower is better\n",
    "    early_stopping_patience=5, # No. of epochs of degradation training will wait before terminating\n",
    "    checkpoints=\"valid_loss\", # Save best checkpoint monitoring valid_loss\n",
    "    load_best=True, # After training, load the best checkpoint\n",
    "    # accelerator=\"cpu\"\n",
    ")\n",
    "\n",
    "model_config_gandalf = GANDALFConfig(\n",
    "    task='classification',\n",
    "    metrics=[\"f1_score\", \"accuracy\", \"auroc\"],\n",
    "    # One flag per metric, in the same order as `metrics`:\n",
    "    # True passes predicted probabilities to the metric, False passes predicted classes\n",
    "    metrics_prob_input=[True, False, True],\n",
    "    gflu_stages=6,\n",
    "    gflu_dropout=0.2,\n",
    ")\n",
    "\n",
    "tabular_model = TabularModel(\n",
    "    data_config=data_config_2nd,\n",
    "    model_config=model_config_gandalf,\n",
    "    optimizer_config=optimizer_config,\n",
    "    trainer_config=trainer_gl_config,\n",
    ")\n",
    "\n",
    "tabular_model.fit(train=train, validation=val)\n",
    "\n",
    "# The first entry of the evaluation output maps metric names to values; cast them to plain floats\n",
    "eval_output = tabular_model.evaluate(test)\n",
    "result = {k: float(v) for k, v in eval_output[0].items()}\n",
    "\n",
    "# Metrics suffixed _0 / _1 are reported for the first / second target respectively\n",
    "result1 = pd.DataFrame({'f1':result['test_f1_score_0'],'auroc':result['test_auroc_0']},\n",
    "             index=['1st Target (Gandalf, multi-target mode)'])\n",
    "result2 = pd.DataFrame({'f1':result['test_f1_score_1'],'auroc':result['test_auroc_1']},\n",
    "             index=['2nd Target (Gandalf, multi-target mode)'])\n",
    "\n",
    "# Drop rows left over from a previous run of this cell so that re-running it\n",
    "# replaces the Gandalf rows in the summary instead of duplicating them\n",
    "gandalf_labels = list(result1.index) + list(result2.index)\n",
    "results = [r for r in results if not r.index.isin(gandalf_labels).any()]\n",
    "results.extend([result1, result2])\n",
    "\n",
    "pd.concat(results)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Results Summary\n",
    "\n",
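    "Here we see that the deeper Gandalf model in multi-target mode performed worse on both targets than either variant of the Category Embedding Model. In particular, its AUROC of roughly 0.50 on the second target is no better than random guessing, so with these settings the model learned essentially nothing about that target. As before, additional hyperparameter tuning may be required for a fair comparison."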
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.18"
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
